# -*- coding: utf-8 -*-
from __future__ import print_function

import sys
import random
import numpy as np
import math

def _to_list(x):
    """Return ``x`` unchanged if it is a list, otherwise wrap it as ``[x]``.

    Used after ``ndarray.tolist()``, which yields a bare scalar (not a list)
    for 0-d arrays.
    """
    return x if isinstance(x, list) else [x]

def map(y_true, y_pred, rel_threshold=0):
    """Average precision (AP) of the ranking for one query.

    Documents are ranked by ``y_pred`` in descending order. A document is
    relevant when its ``y_true`` label is greater than ``rel_threshold``.
    The result is the mean, over relevant documents, of the precision at that
    document's rank. It is 0.0 when no document is relevant.

    NOTE: this name shadows the builtin ``map`` inside this module.
    """
    labels = _to_list(np.squeeze(y_true).tolist())
    scores = _to_list(np.squeeze(y_pred).tolist())
    ranked = list(zip(labels, scores))
    # Shuffle before the stable sort so that tied scores are ordered at random.
    random.shuffle(ranked)
    ranked = sorted(ranked, key=lambda pair: pair[1], reverse=True)

    hits = 0.
    precisions = []
    for rank, (label, _) in enumerate(ranked, 1):
        if label > rel_threshold:
            hits += 1.
            precisions.append(hits / rank)
    if hits == 0:
        return 0.
    return sum(precisions, 0.) / hits

def mrr(y_true, y_pred, rel_threshold=0.):
    """Reciprocal rank of the first relevant document for one query.

    Documents are ranked by ``y_pred`` in descending order. Ties are broken
    at random, because the pairs are shuffled before a stable sort, the same
    convention ``map`` uses. The first document whose ``y_true`` label exceeds
    ``rel_threshold`` at 1-based rank ``r`` yields ``1 / r``. If no document
    is relevant, the result is 0.0.

    NOTE(review): the previous stub declared ``k = 10`` but never used it. No
    rank cut-off is applied here; confirm whether MRR@10 was intended.
    """
    labels = _to_list(np.squeeze(y_true).tolist())
    scores = _to_list(np.squeeze(y_pred).tolist())
    ranked = list(zip(labels, scores))
    random.shuffle(ranked)
    ranked = sorted(ranked, key=lambda pair: pair[1], reverse=True)
    for rank, (label, _) in enumerate(ranked, 1):
        if label > rel_threshold:
            return 1. / rank
    return 0.

def ndcg(k=10):
    """Return a scorer that computes NDCG@k for the documents of one query.

    The returned ``top_k(y_true, y_pred, rel_threshold=0.)`` ranks documents
    by ``y_pred``. It compares the discounted cumulative gain of the first
    ``k`` positions against the ideal ranking, which sorts documents by
    ``y_true``.

    Only labels above ``rel_threshold`` contribute. Each contributes a gain
    of ``2**label - 1`` with a discount of ``ln(2 + position)``, where
    positions are 0-based. The natural-log base cancels in the ratio.

    The scorer returns 0.0 when ``k <= 0`` or when the ideal DCG is zero.
    """
    def top_k(y_true, y_pred, rel_threshold=0.):
        if k <= 0.:
            return 0.
        labels = _to_list(np.squeeze(y_true).tolist())
        scores = _to_list(np.squeeze(y_pred).tolist())
        pairs = list(zip(labels, scores))
        # Shuffle first so the stable sorts below break ties at random.
        random.shuffle(pairs)
        by_label = sorted(pairs, key=lambda pair: pair[0], reverse=True)
        by_score = sorted(pairs, key=lambda pair: pair[1], reverse=True)

        def dcg(ranking):
            # Discounted gain of the relevant documents in the first k slots.
            return sum(
                ((math.pow(2., label) - 1.) / math.log(2. + pos)
                 for pos, (label, _) in enumerate(ranking)
                 if pos < k and label > rel_threshold),
                0.)

        ideal = dcg(by_label)
        actual = dcg(by_score)
        if ideal == 0.:
            return 0.
        return actual / ideal
    return top_k

def precision(k=10):
    """Return a scorer that computes precision@k for the documents of one query.

    The returned ``top_k(y_true, y_pred, rel_threshold=0.)`` ranks documents
    by ``y_pred`` in descending order, with ties broken at random. It counts
    the documents among the first ``k`` whose label exceeds
    ``rel_threshold`` and divides that count by ``k``. It divides by ``k``
    even when fewer than ``k`` documents exist. The scorer returns 0.0 when
    ``k <= 0``.
    """
    def top_k(y_true, y_pred, rel_threshold=0.):
        if k <= 0:
            return 0.
        labels = _to_list(np.squeeze(y_true).tolist())
        scores = _to_list(np.squeeze(y_pred).tolist())
        ranked = list(zip(labels, scores))
        random.shuffle(ranked)
        ranked.sort(key=lambda pair: pair[1], reverse=True)
        hits = float(sum(1 for rank, (label, _) in enumerate(ranked)
                         if rank < k and label > rel_threshold))
        return hits / k
    return top_k

def recall(k=10):
    """Return a scorer that computes recall@k for all documents of one query.

    The returned ``top_k(y_true, y_pred, rel_threshold=0.)`` receives the
    ground-truth labels and predicted scores of every document under a
    single query. It ranks the documents by ``y_pred``, with ties broken at
    random. It then returns the number of relevant documents
    (``label > rel_threshold``) within the first ``k``, divided by the total
    number of relevant documents.

    The scorer returns 0.0 when ``k <= 0`` or when the query has no relevant
    documents.
    """
    def top_k(y_true, y_pred, rel_threshold=0.):
        if k <= 0:
            return 0.
        labels = _to_list(np.squeeze(y_true).tolist())
        scores = _to_list(np.squeeze(y_pred).tolist())
        n_relevant = sum(label > rel_threshold for label in labels)
        ranked = list(zip(labels, scores))
        random.shuffle(ranked)
        ranked.sort(key=lambda pair: pair[1], reverse=True)
        hits = float(sum(1 for rank, (label, _) in enumerate(ranked)
                         if rank < k and label > rel_threshold))
        if n_relevant <= 0:
            return 0.0
        return hits / n_relevant
    return top_k

def mse(y_true, y_pred, rel_threshold=0.):
    """Mean squared error between ``y_true`` and ``y_pred``.

    Singleton axes are squeezed out first. The mean is then taken over the
    last axis, so 1-D input gives a scalar and 2-D input gives one value per
    row. A scalar input is treated as a one-element vector.

    ``rel_threshold`` is unused. It is accepted so that the signature matches
    the ranking metrics in this module.
    """
    # Keep float arrays rather than Python lists: lists do not support
    # element-wise subtraction.
    y_true = np.atleast_1d(np.squeeze(np.asarray(y_true, dtype=float)))
    y_pred = np.atleast_1d(np.squeeze(np.asarray(y_pred, dtype=float)))
    return np.mean(np.square(y_pred - y_true), axis=-1)

# Score cut-off used by accuracy2, precision2 and recall2: a prediction is
# treated as positive when y_pred >= this value.
# NOTE(review): the origin of 0.38 is not documented here -- presumably tuned
# on validation data; confirm before changing.
threshold_for_category = 0.38

def accuracy2(y_true, y_pred, threshold=None):
    """Binary accuracy after thresholding the predicted scores.

    Args:
        y_true: ground-truth labels of any shape; they are flattened, and
            entries equal to 1 count as positive.
        y_pred: predicted scores of any shape; they are flattened, and
            scores >= ``threshold`` count as positive.
        threshold: decision cut-off. Defaults to the module-level
            ``threshold_for_category``.

    Returns:
        The fraction of positions where the thresholded prediction agrees
        with ``y_true == 1``, or 0.0 for empty input.
    """
    if threshold is None:
        threshold = threshold_for_category
    # Flatten both sides so that (N, 1) labels do not broadcast against
    # (N,) predictions into an (N, N) matrix.
    pred_pos = np.asarray(y_pred).flatten() >= threshold
    true_pos = np.asarray(y_true).flatten() == 1
    if true_pos.size == 0:
        return 0.0
    return 1.0 * np.count_nonzero(np.equal(pred_pos, true_pos)) / true_pos.size

def precision2(y_true, y_pred, threshold=None):
    """Binary precision after thresholding the predicted scores.

    Args:
        y_true: ground-truth labels of any shape; they are flattened, and
            entries equal to 1 count as positive.
        y_pred: predicted scores of any shape; they are flattened, and
            scores >= ``threshold`` count as predicted positive.
        threshold: decision cut-off. Defaults to the module-level
            ``threshold_for_category``.

    Returns:
        True positives divided by predicted positives, or 0.0 when nothing
        is predicted positive.
    """
    if threshold is None:
        threshold = threshold_for_category
    # Flatten both sides so that (N, 1) labels cannot broadcast to (N, N).
    pred_pos = np.asarray(y_pred).flatten() >= threshold
    true_pos = np.asarray(y_true).flatten() == 1
    n_predicted = np.count_nonzero(pred_pos)
    if n_predicted == 0:
        return 0.0
    n_hit = np.count_nonzero(np.logical_and(true_pos, pred_pos))
    return n_hit * 1.0 / n_predicted

def recall2(y_true, y_pred, threshold=None):
    """Binary recall after thresholding the predicted scores.

    Args:
        y_true: ground-truth labels of any shape; they are flattened, and
            entries equal to 1 count as positive.
        y_pred: predicted scores of any shape; they are flattened, and
            scores >= ``threshold`` count as predicted positive.
        threshold: decision cut-off. Defaults to the module-level
            ``threshold_for_category``.

    Returns:
        True positives divided by actual positives, or 0.0 when there are
        no actual positives.
    """
    if threshold is None:
        threshold = threshold_for_category
    # Flatten both sides so that (N, 1) labels cannot broadcast to (N, N).
    pred_pos = np.asarray(y_pred).flatten() >= threshold
    true_pos = np.asarray(y_true).flatten() == 1
    n_actual = np.count_nonzero(true_pos)
    if n_actual == 0:
        return 0.0
    n_hit = np.count_nonzero(np.logical_and(true_pos, pred_pos))
    return n_hit * 1.0 / n_actual

def f1(y_true, y_pred):
    """Binary F1 score: the harmonic mean of ``precision2`` and ``recall2``.

    Returns 0.0 when precision and recall are both zero.
    """
    prec = precision2(y_true, y_pred)
    rec = recall2(y_true, y_pred)
    denom = prec + rec
    if denom == 0.0:
        return 0.0
    return 2 * prec * rec / denom

def accuracy(y_true, y_pred):
    """Multi-class accuracy, taking each row's argmax as its class.

    Each row of ``y_true`` / ``y_pred`` holds one sample's per-class values
    (presumably one-hot labels and class scores). Stray singleton axes are
    squeezed out.

    Returns:
        The fraction of rows whose argmax classes match, or 0.0 for empty
        input.

    Raises:
        ValueError: if the two inputs yield a different number of rows.
    """
    # np.squeeze removes stray singleton axes, e.g. (N, 1, C) -> (N, C).
    # It would also collapse a batch of one sample to (C,), so atleast_2d
    # restores the sample axis that argmax(axis=1) relies on.
    y_true = np.atleast_2d(np.squeeze(np.asarray(y_true)))
    y_pred = np.atleast_2d(np.squeeze(np.asarray(y_pred)))
    if y_true.size == 0:
        return 0.0
    y_true_idx = np.argmax(y_true, axis=1)
    y_pred_idx = np.argmax(y_pred, axis=1)
    # Raise explicitly; an assert would be stripped under `python -O`.
    if y_true_idx.shape != y_pred_idx.shape:
        raise ValueError("y_true and y_pred have incompatible shapes: %s vs %s"
                         % (y_true.shape, y_pred.shape))
    return 1.0 * np.sum(y_true_idx == y_pred_idx) / len(y_true)


def precision0(y_true, y_pred):
    """Precision for class index 0, taking each row's argmax as its class.

    Each row of ``y_true`` / ``y_pred`` holds one sample's per-class values
    (presumably one-hot labels and class scores).

    Returns:
        Of the rows predicted as class 0, the fraction whose true class is
        also 0. Returns 0.0 when no row is predicted as class 0, or when the
        input is empty.
    """
    # np.squeeze removes stray singleton axes, but it would also collapse a
    # batch of one sample to (C,). atleast_2d restores the sample axis.
    y_true = np.atleast_2d(np.squeeze(np.asarray(y_true)))
    y_pred = np.atleast_2d(np.squeeze(np.asarray(y_pred)))
    if y_true.size == 0 or y_pred.size == 0:
        return 0.0
    y_true_idx = np.argmax(y_true, axis=1)
    y_pred_idx = np.argmax(y_pred, axis=1)

    n_predicted = np.count_nonzero(y_pred_idx == 0)
    n_hit = np.count_nonzero((y_true_idx == 0) & (y_pred_idx == 0))
    return (1.0 * n_hit / n_predicted) if n_predicted > 0 else 0.0

def recall0(y_true, y_pred):
    """Recall for class index 0, taking each row's argmax as its class.

    Each row of ``y_true`` / ``y_pred`` holds one sample's per-class values
    (presumably one-hot labels and class scores).

    Returns:
        Of the rows whose true class is 0, the fraction also predicted as
        class 0. Returns 0.0 when no row is truly class 0, or when the input
        is empty.
    """
    # np.squeeze removes stray singleton axes, but it would also collapse a
    # batch of one sample to (C,). atleast_2d restores the sample axis.
    y_true = np.atleast_2d(np.squeeze(np.asarray(y_true)))
    y_pred = np.atleast_2d(np.squeeze(np.asarray(y_pred)))
    if y_true.size == 0 or y_pred.size == 0:
        return 0.0
    y_true_idx = np.argmax(y_true, axis=1)
    y_pred_idx = np.argmax(y_pred, axis=1)

    n_actual = np.count_nonzero(y_true_idx == 0)
    n_hit = np.count_nonzero((y_true_idx == 0) & (y_pred_idx == 0))
    return (1.0 * n_hit / n_actual) if n_actual > 0 else 0.0

def f1_0(y_true, y_pred):
    """F1 score for class index 0: the harmonic mean of ``precision0`` and ``recall0``.

    Returns 0.0 when precision and recall are both zero.
    """
    prec = precision0(y_true, y_pred)
    rec = recall0(y_true, y_pred)
    denom = prec + rec
    if denom == 0.0:
        return 0.0
    return 2.0 * prec * rec / denom

def precision1(y_true, y_pred):
    """Precision for class index 1, taking each row's argmax as its class.

    Each row of ``y_true`` / ``y_pred`` holds one sample's per-class values
    (presumably one-hot labels and class scores).

    Returns:
        Of the rows predicted as class 1, the fraction whose true class is
        also 1. Returns 0.0 when no row is predicted as class 1, or when the
        input is empty.
    """
    # np.squeeze removes stray singleton axes, but it would also collapse a
    # batch of one sample to (C,). atleast_2d restores the sample axis.
    y_true = np.atleast_2d(np.squeeze(np.asarray(y_true)))
    y_pred = np.atleast_2d(np.squeeze(np.asarray(y_pred)))
    if y_true.size == 0 or y_pred.size == 0:
        return 0.0
    y_true_idx = np.argmax(y_true, axis=1)
    y_pred_idx = np.argmax(y_pred, axis=1)

    n_predicted = np.count_nonzero(y_pred_idx == 1)
    n_hit = np.count_nonzero((y_true_idx == 1) & (y_pred_idx == 1))
    return (1.0 * n_hit / n_predicted) if n_predicted > 0 else 0.0

def recall1(y_true, y_pred):
    """Recall for class index 1, taking each row's argmax as its class.

    Each row of ``y_true`` / ``y_pred`` holds one sample's per-class values
    (presumably one-hot labels and class scores).

    Returns:
        Of the rows whose true class is 1, the fraction also predicted as
        class 1. Returns 0.0 when no row is truly class 1, or when the input
        is empty.
    """
    # np.squeeze removes stray singleton axes, but it would also collapse a
    # batch of one sample to (C,). atleast_2d restores the sample axis.
    y_true = np.atleast_2d(np.squeeze(np.asarray(y_true)))
    y_pred = np.atleast_2d(np.squeeze(np.asarray(y_pred)))
    if y_true.size == 0 or y_pred.size == 0:
        return 0.0
    y_true_idx = np.argmax(y_true, axis=1)
    y_pred_idx = np.argmax(y_pred, axis=1)

    n_actual = np.count_nonzero(y_true_idx == 1)
    n_hit = np.count_nonzero((y_true_idx == 1) & (y_pred_idx == 1))
    return (1.0 * n_hit / n_actual) if n_actual > 0 else 0.0

def f1_1(y_true, y_pred):
    """F1 score for class index 1: the harmonic mean of ``precision1`` and ``recall1``.

    Returns 0.0 when precision and recall are both zero.
    """
    prec = precision1(y_true, y_pred)
    rec = recall1(y_true, y_pred)
    denom = prec + rec
    if denom == 0.0:
        return 0.0
    return 2.0 * prec * rec / denom